package org.springblade.thingsphere.network.protocol.impl;

import lombok.extern.slf4j.Slf4j;
import org.springblade.core.log.exception.ServiceException;
import org.springblade.thingsphere.CompositeProtocolSupport;
import org.springblade.thingsphere.ProtocolSupportProvider;
import org.springblade.thingsphere.network.protocol.ProtocolAsset;
import org.springblade.thingsphere.network.protocol.ProtocolClassLoader;
import org.springblade.thingsphere.network.protocol.ProtocolSupportDefinition;
import org.springframework.stereotype.Component;

import java.io.File;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author lhb
 * @date 2024/9/26 下午6:28
 * 协议包初始化
 */
@Slf4j
@Component
public class ProtocolAssetSupplier implements ProtocolAsset {

	private final static Map<Long, ProtocolClassLoader> protocolLoaders = new ConcurrentHashMap<>();
	private final static Map<String, ProtocolSupportProvider> loaded = new ConcurrentHashMap<>();

	private final static Map<Long, CompositeProtocolSupport> protocolSupport = new ConcurrentHashMap<>();

	public Map<Long, CompositeProtocolSupport> getProtocolSupport() {
		return protocolSupport;
	}

	public CompositeProtocolSupport getProtocolSupportById(Long id) {
		return protocolSupport.get(id);
	}

	/**
	 * 动态加载jar
	 *
	 * @param definition
	 * @return
	 */
	@Override
	public CompositeProtocolSupport load(ProtocolSupportDefinition definition) {

		if (protocolSupport.containsKey(definition.getId())) {
			return protocolSupport.get(definition.getId());
		}

		try {
			String location = definition.getJarPath();
			String provider = definition.getClazzName();
			URL url;
			if (!location.contains("://")) {
				url = (new File(location)).toURI().toURL();
			} else {
				url = new URL("jar:" + location + "!/");
			}

			URL fLocation = url;
			ProtocolSupportProvider supportProvider = (ProtocolSupportProvider) loaded.remove(provider);
			if (null != supportProvider) {
				supportProvider.dispose();
			}

			ProtocolClassLoader loader = (ProtocolClassLoader) protocolLoaders.compute(definition.getId(), (key, old) -> {
				if (null != old) {
					try {
						this.closeLoader(old);
					} catch (Exception var5) {
						log.error("加载jar包失败", var5);
					}
				}
				return this.createClassLoader(fLocation);
			});

			log.debug("load protocol support from : {}", location);
			if (provider != null) {
				supportProvider = (ProtocolSupportProvider) Class.forName(provider, true, loader).newInstance();
			} else {
				supportProvider = (ProtocolSupportProvider) ServiceLoader.load(ProtocolSupportProvider.class, loader).iterator().next();
			}

			CompositeProtocolSupport compositeProtocolSupport = supportProvider.create();

			loaded.put(provider, supportProvider);
			protocolSupport.put(definition.getId(), compositeProtocolSupport);
			return compositeProtocolSupport;
		} catch (Exception e) {
			throw new ServiceException(e.getMessage());
		}
	}

	protected void closeLoader(ProtocolClassLoader loader) {
		try {
			loader.close();
		} catch (Throwable var3) {
			throw new RuntimeException(var3);
		}
	}

	protected ProtocolClassLoader createClassLoader(URL location) {
		return new ProtocolClassLoader(new URL[]{location}, this.getClass().getClassLoader());
	}


	// 卸载jar
	@Override
	public boolean unLoadJar(ProtocolSupportDefinition definition) {
		try {
			URLClassLoader classLoader = protocolLoaders.get(definition.getId());
			if (classLoader != null) {
				classLoader.close();
			}
			loaded.remove(definition.getClazzName());
			protocolLoaders.remove(definition.getId());
			protocolSupport.remove(definition.getId());
			return true;
		} catch (Exception e) {
			log.error(e.getMessage(), e);
			return false;
		}
	}

}
